import argparse
import logging
import numpy as np
import torch
from gym.wrappers import RecordVideo
import cv2
from PIL import Image, ImageDraw, ImageFont
from hyperparameters import get_hyperparameters
from common import make_env, create_folders
from sac import SAC, SACGRU

# Torch device for this module: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _sac_kwargs(hy):
    """Return the constructor kwargs shared by the SAC and SAC-GRU agents.

    Args:
        hy: hyperparameter dict from ``get_hyperparameters``; must contain
            'discount', 'tau', 'alpha', 'hidden_size', 'target_update_interval'
            and 'lr'.
    """
    return {
        "gamma": hy['discount'],
        "tau": hy['tau'],
        "alpha": hy['alpha'],
        "policy_type": "Gaussian",
        "hidden_size": hy['hidden_size'],
        "target_update_interval": hy['target_update_interval'],
        "automatic_entropy_tuning": True,
        "lr": hy['lr'],
    }


def _annotate_videos(env_name, prefix, num_episodes, interval, font):
    """Stamp a star on every ``interval``-th frame of each recorded episode.

    For each ``i`` in ``range(num_episodes)``, reads
    ``<env_name>/<prefix>-episode-<i>.mp4``. It writes a copy with the marker
    drawn at (50, 50) to ``<env_name>/<prefix>-episode-<i>_modified.mp4``.
    Videos that cannot be opened are skipped with a warning.

    Args:
        env_name: folder the videos were recorded into.
        prefix: ``name_prefix`` that was given to ``RecordVideo``.
        num_episodes: number of episode videos to process.
        interval: mark frames whose index is a multiple of this (must be >= 1).
        font: PIL font used to render the marker.
    """
    symbol = "★"  # The symbol to insert
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')

    for episode in range(num_episodes):
        stem = env_name + '/' + prefix + '-episode-' + str(episode)
        video_path = stem + '.mp4'
        output_path = stem + '_modified.mp4'

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            # Without this guard a writer with 0x0 size / 0 fps would be created
            # and an empty, broken output file left behind.
            logging.warning("Could not open %s; skipping", video_path)
            cap.release()
            continue

        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fps as reported (a float); truncating e.g. 29.97 -> 29 makes the
        # output play back at the wrong speed.
        fps = cap.get(cv2.CAP_PROP_FPS)
        out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))

        try:
            frame_count = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_count % interval == 0:
                    # Drawn via PIL, presumably because cv2.putText's Hershey fonts
                    # cannot render the Unicode star. OpenCV frames are BGR, PIL expects RGB.
                    frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    draw = ImageDraw.Draw(frame_pil)
                    draw.text((50, 50), symbol, font=font, fill=(242, 206, 22))  # gold/yellow
                    frame = cv2.cvtColor(np.array(frame_pil), cv2.COLOR_RGB2BGR)
                out.write(frame)
                frame_count += 1
        finally:
            # Release both handles even if decoding/drawing raises mid-video.
            cap.release()
            out.release()

        print("Video processing completed!")


def record(env_name, seed_gru, seed_sac, steps, train_steps, num_episodes=10,
           font_path="DejaVuSans-Bold.ttf"):
    """Record evaluation videos for a SAC and a SAC-GRU agent and mark decision frames.

    Workflow:
    1. Loads the SAC checkpoint ``./models/SAC_<env_name>_<seed_sac>_True_best``.
    2. Runs ``num_episodes`` evaluation episodes, recorded to ``./<env_name>/SAC-episode-*.mp4``.
    3. Stars every frame of those videos, since SAC picks a new action at every step.
    4. Loads the SAC-GRU checkpoint ``./models/SAC_GRU_<env_name>_<seed_gru>_True_<train_steps>_4_best``.
    5. Runs and records ``GRU-episode-*.mp4`` the same way. Each policy query
       yields ``steps`` actions, which are executed open-loop.
    6. Stars every ``steps``-th frame of the GRU videos, i.e. the frames where a new plan is requested.

    Args:
        env_name: gym environment id; also the output video folder.
        seed_gru: seed component of the SAC-GRU checkpoint file name.
        seed_sac: seed component of the SAC checkpoint file name.
        steps: actions executed per SAC-GRU query; must be >= 1.
        train_steps: ``steps`` value the SAC-GRU agent is constructed with and
            that appears in its checkpoint name.
        num_episodes: episodes recorded per agent (default 10, as before).
        font_path: TTF font used to draw the marker.

    Raises:
        ValueError: if ``steps < 1``. With 0, the GRU rollout would never step
            the environment and would loop forever.
        OSError: if ``font_path`` cannot be loaded. This is checked before any
            rollout runs.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    # Load the font up front: a missing font should fail before the (slow) rollouts.
    font = ImageFont.truetype(font_path, 40)
    hy = get_hyperparameters(env_name, 'SAC')

    # ---- Plain SAC: one action per environment step. ----
    file_name = '_'.join(str(x) for x in ["SAC", env_name, seed_sac, True, "best"])
    # NOTE(review): meaning of make_env's second argument (100) is not visible
    # here — presumably a seed; confirm in common.make_env.
    env = make_env(env_name, 100)
    eval_env = RecordVideo(env, './' + env_name, name_prefix="SAC",
                           episode_trigger=lambda x: x < num_episodes)
    state_dim = eval_env.observation_space.shape[0]
    policy = SAC(state_dim, eval_env.action_space, **_sac_kwargs(hy))
    policy.load_checkpoint(f"./models/{file_name}")

    # Old gym API: reset() -> obs, step() -> (obs, reward, done, info).
    for episode in range(num_episodes):
        eval_state, eval_done = eval_env.reset(), False
        episode_reward = 0
        while not eval_done:
            eval_action = policy.select_action(eval_state, evaluate=True)
            eval_state, eval_reward, eval_done, _ = eval_env.step(eval_action)
            episode_reward += eval_reward
        logging.info("SAC episode %d reward: %s", episode, episode_reward)

    eval_env.close()
    _annotate_videos(env_name, "SAC", num_episodes, 1, font)

    # ---- SAC-GRU: each query returns `steps` actions, executed open-loop. ----
    env = make_env(env_name, 100)
    eval_env = RecordVideo(env, './' + env_name, name_prefix="GRU",
                           episode_trigger=lambda x: x < num_episodes)
    state_dim = eval_env.observation_space.shape[0]
    action_dim = eval_env.action_space.shape[0]
    kwargs = _sac_kwargs(hy)
    kwargs["steps"] = train_steps

    policy = SACGRU(state_dim, eval_env.action_space, **kwargs)
    file_name = f"SAC_GRU_{env_name}_{seed_gru}_True_{train_steps}_4_best"
    policy.load_checkpoint(f"./models/{file_name}")

    for episode in range(num_episodes):
        eval_state, eval_done = eval_env.reset(), False
        episode_reward = 0
        eval_prev_action = torch.zeros(action_dim)
        while not eval_done:
            eval_actions = policy.select_action(eval_state, eval_prev_action, steps, evaluate=True)
            for step in range(steps):
                eval_action = eval_actions[step]
                eval_state, eval_reward, eval_done, _ = eval_env.step(eval_action)
                eval_prev_action = eval_action
                episode_reward += eval_reward
                if eval_done:
                    # Drop the rest of the plan once the episode ends.
                    break
        logging.info("GRU episode %d reward: %s", episode, episode_reward)

    eval_env.close()
    # Frame k is the state after k env steps, so multiples of `steps` are
    # exactly the frames where a new plan was requested.
    _annotate_videos(env_name, "GRU", num_episodes, steps, font)



if __name__ == "__main__":
    # Without configuration the root logger drops INFO records, so the argument
    # echo below (and any per-episode logging) would never be shown.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument("--env_name", default="LunarLanderContinuous-v2", help="Environment name")
    # The seeds only select which trained checkpoint is loaded (they appear in
    # the model file names); they do not seed this evaluation run.
    parser.add_argument("--seed_gru", default=0, type=int,
                        help="Seed of the SAC-GRU training run whose checkpoint is loaded")
    parser.add_argument("--seed_sac", default=0, type=int,
                        help="Seed of the SAC training run whose checkpoint is loaded")

    parser.add_argument("--steps", default=2, type=int,
                        help="Number of steps to plan ahead (actions executed per SAC-GRU query)")
    parser.add_argument("--train_steps", default=2, type=int,
                        help="Planning horizon the SAC-GRU checkpoint was trained with")

    args = vars(parser.parse_args())
    logging.info('Command-line argument values:')
    for key, value in args.items():
        logging.info('- %s : %s', key, value)

    record(**args)
